Re-weighting#

PolicyEngine-UK primarily relies on the Family Resources Survey, which has known issues with non-capture of households at the bottom and top of the income distribution. To correct for this, we apply a weight modification, optimised using gradient descent to minimise survey error against a diverse selection of target statistics. These include:

  • Regional populations

  • Household populations

  • Population by tenure type

  • Population by Council Tax band

  • Country-level program statistics

  • UK-wide program aggregates

  • UK-wide program caseloads

The graph below shows the effect of the optimisation on each of these, compared to their starting values (under original FRS weights). All loss subfunctions improve from their starting values.

Hide code cell source
# Plot the change in training loss by target category over the course of the
# no-validation-split reweighting run.
import pandas as pd
import numpy as np
import plotly.express as px

# Per-target, per-epoch loss log from the reweighting training run.
df = pd.read_csv(
    "https://github.com/PolicyEngine/openfisca-uk-reweighting/raw/master/no_val_split/training_log_run_1.csv.gz",
)

# Total loss per (category, epoch), pivoted so each category is a column
# indexed by epoch.
ldf = (
    df.groupby(["category", "epoch"])
    .sum()
    .reset_index()
    .pivot(columns="category", values="loss", index="epoch")
)
# Express each category's loss relative to its epoch-0 (original FRS weights)
# value: e.g. -0.25 means the loss has fallen 25% from its starting value.
ldf /= ldf.loc[0]
ldf -= 1
# Long format for plotly: one row per (epoch, category) pair.
ldf = ldf.reset_index().melt(id_vars=["epoch"])

# Human-readable hover text for each point.
ldf["hover"] = [
    f"At epoch {epoch}, the total loss from targets <br>in the category <b>{category}</b> <br>has <b>{'risen' if value > 0 else 'fallen'}</b> by <b>{abs(value):.1%}</b>."
    for epoch, category, value in zip(ldf.epoch, ldf.category, ldf.value)
]

px.line(
    ldf, x="epoch", y="value", color="category", custom_data=[ldf.hover]
).update_traces(hovertemplate="%{customdata[0]}").update_layout(
    title="Training performance by category",
    height=600,
    width=800,
    xaxis_title="Epoch",
    yaxis_title="Loss change",
    legend_title="Category",
    yaxis_range=(-1, 0),
    yaxis_tickformat=".0%",
)

Changes to distributions#

Validation#

During initial training, we split the targets into training and validation groups (80%/20%), performing 5-fold cross-validation. The graph below shows the performance of validation metrics in each fold, as well as the average over the five folds.

Hide code cell source
# Per-target, per-epoch loss log from the 80/20 train/validation split runs.
df = pd.read_csv(
    "https://github.com/PolicyEngine/openfisca-uk-reweighting/raw/master/train_val_split/training_log.csv.gz",
    compression="gzip",
)

# Display label for each subset of targets: validation-only, training-only,
# and all targets together.
TYPE_LABELS = {
    True: "Validation",
    False: "Training",
    "Both": "Training + Validation",
}

frames = []
for subset, label in TYPE_LABELS.items():
    # "Both" selects every row; otherwise match the boolean validation flag.
    if subset == "Both":
        mask = df.validation | ~df.validation
    else:
        mask = df.validation == subset
    # Total loss per (fold, epoch), pivoted so each fold (run_id) is a column.
    losses = (
        df[mask]
        .groupby(["run_id", "epoch"])
        .loss.sum()
        .reset_index()
        .pivot(columns="run_id", values="loss", index="epoch")
    )
    # Relative change from each fold's epoch-0 loss.
    losses /= losses.loc[0]
    losses -= 1
    losses = losses.dropna()
    # Mean across the five folds, shown as an extra darker line.
    losses["Average"] = losses.mean(axis=1)
    losses["Type"] = label
    frames.append(losses)
xdf = pd.concat([pd.DataFrame()] + frames)
px.line(
    xdf,
    y=xdf.columns,
    animation_frame="Type",
    color_discrete_sequence=["lightgrey"] * 5 + ["grey"],
).update_layout(
    title="5-fold cross-validation training",
    yaxis_title="Relative loss change",
    yaxis_tickformat=".0%",
    xaxis_title="Epoch",
    legend_title="Fold",
    width=800,
    height=800,
)
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[2], line 1
----> 1 df = pd.read_csv(
      2     "https://github.com/PolicyEngine/openfisca-uk-reweighting/raw/master/train_val_split/training_log.csv.gz",
      3     compression="gzip",
      4 )
      5 xdf = pd.DataFrame()
      6 for validation_type in (True, False, "Both"):

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/pandas/io/parsers/readers.py:912, in read_csv(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, skipfooter, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, date_format, dayfirst, cache_dates, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, doublequote, escapechar, comment, encoding, encoding_errors, dialect, on_bad_lines, delim_whitespace, low_memory, memory_map, float_precision, storage_options, dtype_backend)
    899 kwds_defaults = _refine_defaults_read(
    900     dialect,
    901     delimiter,
   (...)
    908     dtype_backend=dtype_backend,
    909 )
    910 kwds.update(kwds_defaults)
--> 912 return _read(filepath_or_buffer, kwds)

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/pandas/io/parsers/readers.py:577, in _read(filepath_or_buffer, kwds)
    574 _validate_names(kwds.get("names", None))
    576 # Create the parser.
--> 577 parser = TextFileReader(filepath_or_buffer, **kwds)
    579 if chunksize or iterator:
    580     return parser

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/pandas/io/parsers/readers.py:1407, in TextFileReader.__init__(self, f, engine, **kwds)
   1404     self.options["has_index_names"] = kwds["has_index_names"]
   1406 self.handles: IOHandles | None = None
-> 1407 self._engine = self._make_engine(f, self.engine)

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/pandas/io/parsers/readers.py:1661, in TextFileReader._make_engine(self, f, engine)
   1659     if "b" not in mode:
   1660         mode += "b"
-> 1661 self.handles = get_handle(
   1662     f,
   1663     mode,
   1664     encoding=self.options.get("encoding", None),
   1665     compression=self.options.get("compression", None),
   1666     memory_map=self.options.get("memory_map", False),
   1667     is_text=is_text,
   1668     errors=self.options.get("encoding_errors", "strict"),
   1669     storage_options=self.options.get("storage_options", None),
   1670 )
   1671 assert self.handles is not None
   1672 f = self.handles.handle

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/pandas/io/common.py:716, in get_handle(path_or_buf, mode, encoding, compression, memory_map, is_text, errors, storage_options)
    713     codecs.lookup_error(errors)
    715 # open URLs
--> 716 ioargs = _get_filepath_or_buffer(
    717     path_or_buf,
    718     encoding=encoding,
    719     compression=compression,
    720     mode=mode,
    721     storage_options=storage_options,
    722 )
    724 handle = ioargs.filepath_or_buffer
    725 handles: list[BaseBuffer]

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/pandas/io/common.py:368, in _get_filepath_or_buffer(filepath_or_buffer, encoding, compression, mode, storage_options)
    366 # assuming storage_options is to be interpreted as headers
    367 req_info = urllib.request.Request(filepath_or_buffer, headers=storage_options)
--> 368 with urlopen(req_info) as req:
    369     content_encoding = req.headers.get("Content-Encoding", None)
    370     if content_encoding == "gzip":
    371         # Override compression based on Content-Encoding header

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/pandas/io/common.py:270, in urlopen(*args, **kwargs)
    264 """
    265 Lazy-import wrapper for stdlib urlopen, as that imports a big chunk of
    266 the stdlib.
    267 """
    268 import urllib.request
--> 270 return urllib.request.urlopen(*args, **kwargs)

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:214, in urlopen(url, data, timeout, cafile, capath, cadefault, context)
    212 else:
    213     opener = _opener
--> 214 return opener.open(url, data, timeout)

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:523, in OpenerDirector.open(self, fullurl, data, timeout)
    521 for processor in self.process_response.get(protocol, []):
    522     meth = getattr(processor, meth_name)
--> 523     response = meth(req, response)
    525 return response

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:632, in HTTPErrorProcessor.http_response(self, request, response)
    629 # According to RFC 2616, "2xx" code indicates that the client's
    630 # request was successfully received, understood, and accepted.
    631 if not (200 <= code < 300):
--> 632     response = self.parent.error(
    633         'http', request, response, code, msg, hdrs)
    635 return response

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:555, in OpenerDirector.error(self, proto, *args)
    553     http_err = 0
    554 args = (dict, proto, meth_name) + args
--> 555 result = self._call_chain(*args)
    556 if result:
    557     return result

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:494, in OpenerDirector._call_chain(self, chain, kind, meth_name, *args)
    492 for handler in handlers:
    493     func = getattr(handler, meth_name)
--> 494     result = func(*args)
    495     if result is not None:
    496         return result

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:747, in HTTPRedirectHandler.http_error_302(self, req, fp, code, msg, headers)
    744 fp.read()
    745 fp.close()
--> 747 return self.parent.open(new, timeout=req.timeout)

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:517, in OpenerDirector.open(self, fullurl, data, timeout)
    514     req = meth(req)
    516 sys.audit('urllib.Request', req.full_url, req.data, req.headers, req.get_method())
--> 517 response = self._open(req, data)
    519 # post-process response
    520 meth_name = protocol+"_response"

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:534, in OpenerDirector._open(self, req, data)
    531     return result
    533 protocol = req.type
--> 534 result = self._call_chain(self.handle_open, protocol, protocol +
    535                           '_open', req)
    536 if result:
    537     return result

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:494, in OpenerDirector._call_chain(self, chain, kind, meth_name, *args)
    492 for handler in handlers:
    493     func = getattr(handler, meth_name)
--> 494     result = func(*args)
    495     if result is not None:
    496         return result

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:1389, in HTTPSHandler.https_open(self, req)
   1388 def https_open(self, req):
-> 1389     return self.do_open(http.client.HTTPSConnection, req,
   1390         context=self._context, check_hostname=self._check_hostname)

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:1350, in AbstractHTTPHandler.do_open(self, http_class, req, **http_conn_args)
   1348     except OSError as err: # timeout error
   1349         raise URLError(err)
-> 1350     r = h.getresponse()
   1351 except:
   1352     h.close()

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/http/client.py:1377, in HTTPConnection.getresponse(self)
   1375 try:
   1376     try:
-> 1377         response.begin()
   1378     except ConnectionError:
   1379         self.close()

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/http/client.py:320, in HTTPResponse.begin(self)
    318 # read until we get a non-100 response
    319 while True:
--> 320     version, status, reason = self._read_status()
    321     if status != CONTINUE:
    322         break

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/http/client.py:281, in HTTPResponse._read_status(self)
    280 def _read_status(self):
--> 281     line = str(self.fp.readline(_MAXLINE + 1), "iso-8859-1")
    282     if len(line) > _MAXLINE:
    283         raise LineTooLong("status line")

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/socket.py:704, in SocketIO.readinto(self, b)
    702 while True:
    703     try:
--> 704         return self._sock.recv_into(b)
    705     except timeout:
    706         self._timeout_occurred = True

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/ssl.py:1242, in SSLSocket.recv_into(self, buffer, nbytes, flags)
   1238     if flags != 0:
   1239         raise ValueError(
   1240           "non-zero flags not allowed in calls to recv_into() on %s" %
   1241           self.__class__)
-> 1242     return self.read(nbytes, buffer)
   1243 else:
   1244     return super().recv_into(buffer, nbytes, flags)

File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/ssl.py:1100, in SSLSocket.read(self, len, buffer)
   1098 try:
   1099     if buffer is not None:
-> 1100         return self._sslobj.read(len, buffer)
   1101     else:
   1102         return self._sslobj.read(len)

KeyboardInterrupt: 

The chart below visualises the effect of the training process on each individual training and validation metric, by epoch.

Hide code cell source
# Animated scatter of per-target relative error against the target's actual
# value, one animation frame per logged epoch, for two monetary categories.
# Relative error of each target's predicted statistic vs its actual value.
df["rel_error"] = df.pred / df.actual - 1
df["Type"] = np.where(df.validation, "Validation", "Training")
# Only epochs at multiples of STEP_SIZE are shown, to keep the animation small.
STEP_SIZE = 50

cdf = df[df.epoch % STEP_SIZE == 0]
# Restrict to the two pound-denominated categories so one x-axis makes sense.
cdf = cdf[
    (cdf.category == "Budgetary impact")
    | (cdf.category == "UK-wide program aggregates")
]

fig = px.scatter(
    cdf,
    animation_frame="epoch",
    x="actual",
    y="rel_error",
    color="Type",
    hover_data=df.columns,
    opacity=0.2,
)
# Shared layout applied to the base figure and re-applied to every animation
# frame below (plotly frames do not inherit the base layout automatically).
layout = dict(
    title="Target metrics",
    width=800,
    height=800,
    legend_title="Type",
    yaxis_title="Relative error",
    yaxis_tickformat=".1%",
    xaxis_tickprefix="£",
    xaxis_title="Actual value",
    yaxis_range=(-1, 1),
)
fig.update_layout(**layout)

for i, frame in enumerate(fig.frames):
    frame.layout.update(layout)
    # Frame i corresponds to the i-th retained epoch, i.e. epoch i*STEP_SIZE
    # (assumes frames are ordered by epoch — the animation_frame default).
    frame.layout[
        "title"
    ] = f"Budgetary impact target metric performance at {i * STEP_SIZE:,} epochs"

# Force a full redraw on each slider step / play button press so the per-frame
# layout (including the title) actually updates.
for step in fig.layout.sliders[0].steps:
    step["args"][1]["frame"]["redraw"] = True

for button in fig.layout.updatemenus[0].buttons:
    button["args"][1]["frame"]["redraw"] = True

import gif
import plotly.graph_objects as go

# Render each animation frame as a static image and save the whole animation
# as a GIF lasting ~3 seconds in total.
gif.save(
    [
        gif.frame(lambda: go.Figure(data=frame.data, layout=frame.layout))()
        for frame in fig.frames
    ],
    "scatterplot.gif",
    duration=3_000 / len(fig.frames),
)

fig